[BAEKJOON] 11437. LCA
Posted on
문제 :
N(2 ≤ N ≤ 50,000)개의 정점으로 이루어진 트리가 주어진다. 트리의 각 정점은 1번부터 N번까지 번호가 매겨져 있으며, 루트는 1번이다.
두 노드의 쌍 M(1 ≤ M ≤ 10,000)개가 주어졌을 때, 두 노드의 가장 가까운 공통 조상이 몇 번인지 출력한다.
입력 :
첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정점 쌍이 주어진다.
출력 :
M개의 줄에 차례대로 입력받은 두 정점의 가장 가까운 공통 조상을 출력한다.
풀이 :
원래는 이 문제가 아닌 다른 그래프 이론 문제를 풀다가 LCA 알고리즘을 활용해야 한다는 것을 알게 된 후 차근차근 LCA부터 공부해보자라는 취지로 문제를 풀어보았다.
- LCA (Lowest Common Ancester) Algorithm
- LCA 알고리즘은 트리에서 두 노드의 최소 공통 부모 노드를 찾는 알고리즘을 말한다. 알고리즘의 경우 로직이 크게 어려운 것은 없으나 만약 LCA 알고리즘에 대해 전혀 모른다면 문제를 조금만 변형하더라도 어렵게 다가올 수 있겠다 싶은 생각이 들었다.
- LCA 알고리즘은 우선 가장 핵심 포인트는 두 노드의 depth 중 더 낮은 depth 의 노드로 깊이를 맞춰줘야 한다는 점이다.
- 가령, 1번 노드가 depth가 4이고, 2번 노드가 depth가 6일 경우 2번 노드를 depth가 4가 될 때까지 부모 노드로 타고 올라와서 계산해주어야 한다.
- depth를 같게 맞춰줬을 때부턴 차례로 같은 depth의 노드가 동일한지 검사해보면 된다.
난 본 문제의 테스트케이스가 꽤 크다 생각하여 dp 방식으로 각 노드별 부모 노드를 모두 저장하여 계산해 주려 했다.
하지만 본 문제는 dp를 굳이 사용하지 않아도 풀 수 있었으며 오히려 dp 를 사용하게 되면 메모리 초과 결과가 출력되었다.
코드 :
코드 보기/접기
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <utility>
#define pii pair<int, int>
using namespace std;
vector<vector<int>> adj;
int *parent, *depth;
void bfs(int n) {
bool *visit = new bool[n + 1]{false};
queue<pii > q;
q.push(pii(1, 1));
depth[1] = 1;
while (!q.empty()) {
int curidx = q.front().first, curdepth = q.front().second, size = adj[curidx].size();
q.pop();
visit[curidx] = true;
for (int k = 0; k < size; k++) {
int cmp = adj[curidx][k];
if (visit[cmp]) continue;
parent[cmp] = curidx;
depth[cmp] = curdepth + 1;
q.push(pii(cmp, curdepth + 1));
}
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
int n, m, i, fir, sec, idx;
cin >> n;
parent = new int[n + 1];
depth = new int[n + 1];
adj.resize(n + 1);
for (i = 1; i < n; i++) {
cin >> fir >> sec;
adj[fir].push_back(sec);
adj[sec].push_back(fir);
}
bfs(n);
cin >> m;
while (m--) {
cin >> fir >> sec;
idx = min(depth[fir], depth[sec]);
if (depth[fir] != idx)
for (int i = depth[fir]; i > idx; i--)
fir = parent[fir];
else if (depth[sec] != idx)
for (int i = depth[sec]; i > idx; i--)
sec = parent[sec];
for (; idx >= 0; idx--) {
if (fir == sec) {
cout << fir << '\n';
break;
}
fir = parent[fir];
sec = parent[sec];
}
}
return 0;
}